nlp_architect.models.gnmt_model.GNMTModel

class nlp_architect.models.gnmt_model.GNMTModel(hparams, mode, iterator, source_vocab_table, target_vocab_table, reverse_target_vocab_table=None, scope=None, extra_args=None)[source]

Sequence-to-sequence dynamic model with the GNMT attention architecture and sparsity policy support.

__init__(hparams, mode, iterator, source_vocab_table, target_vocab_table, reverse_target_vocab_table=None, scope=None, extra_args=None)[source]

Create the model.

Parameters:
  • hparams – Hyperparameter configurations.
  • mode – TRAIN | EVAL | INFER
  • iterator – Dataset Iterator that feeds data.
  • source_vocab_table – Lookup table mapping source words to ids.
  • target_vocab_table – Lookup table mapping target words to ids.
  • reverse_target_vocab_table – Lookup table mapping ids to target words. Only required in INFER mode. Defaults to None.
  • scope – scope of the model.
  • extra_args – model_helper.ExtraArgs, for passing customizable functions.
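
The following is a minimal construction sketch in TensorFlow 1.x style. The hparams object, the dataset iterator, and the vocabulary lookup tables are assumed to be prepared by the surrounding NMT utilities and appear here only as placeholders; the mode value follows the TensorFlow NMT convention of tf.contrib.learn.ModeKeys and is an assumption, not part of this reference.

    import tensorflow as tf
    from nlp_architect.models.gnmt_model import GNMTModel

    # Placeholders: hparams, train_iterator, src_vocab_table and tgt_vocab_table
    # are assumed to be built elsewhere by the accompanying data/model utilities.
    train_model = GNMTModel(
        hparams=hparams,                       # hyperparameter configurations
        mode=tf.contrib.learn.ModeKeys.TRAIN,  # TRAIN | EVAL | INFER
        iterator=train_iterator,               # dataset iterator feeding source/target batches
        source_vocab_table=src_vocab_table,    # source word -> id lookup table
        target_vocab_table=tgt_vocab_table,    # target word -> id lookup table
        scope="gnmt")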

Methods

__init__(hparams, mode, iterator, …[, …]) Create the model.
build_encoder_states([include_embeddings]) Stack encoder states and return tensor [batch, length, layer, size].
build_graph(hparams[, scope]) Subclass must implement this method.
decode(sess) Decode a batch.
eval(sess) Execute eval graph.
get_max_time(tensor)
infer(sess)
init_embeddings(hparams, scope) Init embeddings.
train(sess) Execute train graph.
build_encoder_states(include_embeddings=False)

Stack encoder states and return tensor [batch, length, layer, size].
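
A brief usage sketch, assuming model is an already constructed GNMTModel:

    # Stacked encoder states; the returned tensor has shape [batch, length, layer, size].
    encoder_states = model.build_encoder_states(include_embeddings=False)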

build_graph(hparams, scope=None)

Subclass must implement this method.

Creates a sequence-to-sequence model with dynamic RNN decoder API.

Parameters:
  • hparams – Hyperparameter configurations.
  • scope – VariableScope for the created subgraph; defaults to "dynamic_seq2seq".

Returns:
  A tuple of the form (logits, loss_tuple, final_context_state, sample_id), where:
  • logits – float32 Tensor of shape [batch_size x num_decoder_symbols].
  • loss – the total loss divided by batch_size.
  • final_context_state – the final state of the decoder RNN.
  • sample_id – sampling indices.

Raises:
  ValueError – if encoder_type is neither mono nor bi, or if attention_option is not one of luong | scaled_luong | bahdanau | normed_bahdanau.
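
The return tuple is normally consumed internally when the model is constructed. The sketch below only illustrates unpacking it; model and hparams are assumed to exist, and the variable names are illustrative.

    # Hedged sketch: unpacking the build_graph return value on an existing model.
    # "dynamic_seq2seq" is the documented default scope.
    logits, loss, final_context_state, sample_id = model.build_graph(
        hparams, scope="dynamic_seq2seq")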
decode(sess)

Decode a batch.

Parameters:
  • sess – TensorFlow session to use.

Returns:
  A tuple consisting of (outputs, infer_summary), where outputs has shape [batch_size, time].
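
A hedged inference sketch; infer_model is assumed to be a GNMTModel built in INFER mode and infer_iterator its dataset iterator:

    with tf.Session() as sess:
        # Variables are assumed to have been initialized or restored from a checkpoint.
        sess.run(tf.tables_initializer())
        sess.run(infer_iterator.initializer)
        outputs, infer_summary = infer_model.decode(sess)  # outputs: [batch_size, time]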
eval(sess)

Execute eval graph.

get_max_time(tensor)
infer(sess)
init_embeddings(hparams, scope)

Init embeddings.

train(sess)

Execute train graph.
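
A minimal training-loop sketch; the session, iterator, and step count are assumed to be set up as in the construction example above, and num_train_steps is a hypothetical name for the desired number of steps.

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        sess.run(train_iterator.initializer)
        for _ in range(num_train_steps):           # num_train_steps: assumed to be defined
            step_result = train_model.train(sess)  # runs one step of the train graph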